
change QwenImageTransformer UT to batch inputs #13312

Merged
sayakpaul merged 7 commits into huggingface:main from zhtmike:batch_qwen_ut
Mar 24, 2026

Conversation

@zhtmike (Contributor) commented Mar 23, 2026:

What does this PR do?

This PR expands the inputs of the QwenImage transformer unit tests from a single prompt to multiple prompts. Since QwenImagePipeline already supports multi-prompt inputs, this adds the corresponding test coverage.

  • The following test cases pass:
    TestQwenImageTransformer, TestQwenImageTransformerMemory, TestQwenImageTransformerTraining, TestQwenImageTransformerAttention, TestQwenImageTransformerContextParallel, TestQwenImageTransformerLoRA

  • The following test case fails:
    TestQwenImageTransformerLoRAHotSwap (an independent issue that also occurs on the main branch)

For ContextParallel, we keep batch_size=1 and will address batching in another PR.

In this PR, we also add a check to skip the UT for ring attention when the SDPA backend is enabled. See the discussion in #13278.
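The skip described above can be sketched roughly as follows. This is a hedged illustration only: `ATTENTION_BACKEND`, `ring_attention_supported`, and the test class are hypothetical names, not the actual diffusers test API.

```python
# Illustrative sketch (hypothetical names) of skipping the ring-attention
# test under the default SDPA ("native") attention backend, which has no
# ring-attention implementation.
import os
import unittest

# Assumed env-var convention for this sketch; not the real diffusers config.
ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "native")


def ring_attention_supported(backend: str) -> bool:
    """SDPA ("native") does not support ring attention; others may."""
    return backend != "native"


class QwenImageContextParallelTests(unittest.TestCase):
    @unittest.skipIf(
        not ring_attention_supported(ATTENTION_BACKEND),
        "ring attention is not supported with the SDPA ('native') backend",
    )
    def test_ring_attention(self):
        # The real test would run context-parallel inference here.
        self.assertTrue(ring_attention_supported(ATTENTION_BACKEND))
```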

Fixes # (issue)

  • Expand the inputs for QwenImageTransformer from a single prompt to multiple prompts in UT
  • Skip the ring attention test when the backend is SDPA (“native”).

Who can review?

@sayakpaul

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul (Member) left a comment:


Thanks a lot. I left some further comments. LMK.

@sayakpaul (Member) left a comment:


Just one.

f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)

@pytest.mark.xfail(reason="Context parallel may not support batch_size > 1")
@sayakpaul (Member) commented:


Is it the case for Flux as well?

Also, let's always require get_dummy_inputs() to have batch_size. So, we can safely remove the inspect stuff from here and elsewhere.
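"The inspect stuff" being removed likely refers to a pattern like the one below: probing a test's `get_dummy_inputs()` signature before forwarding `batch_size`. This is an illustrative sketch with hypothetical helper names, not the actual diffusers test code.

```python
# Hypothetical sketch of the signature-probing pattern the review suggests
# removing: only forward batch_size if get_dummy_inputs() declares it.
import inspect


def get_dummy_inputs(batch_size: int = 1):
    # Illustrative fixture: shapes only, leading dim is the batch size.
    return {"hidden_states_shape": (batch_size, 16, 32)}


def call_with_optional_batch_size(fn, batch_size):
    # Old style: inspect the signature to decide whether to pass batch_size.
    if "batch_size" in inspect.signature(fn).parameters:
        return fn(batch_size=batch_size)
    return fn()


inputs = call_with_optional_batch_size(get_dummy_inputs, 2)
```

Once every `get_dummy_inputs()` is required to accept `batch_size`, callers can pass it unconditionally and this probe disappears.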

@sayakpaul (Member) commented:

@zhtmike 👀

@zhtmike (Contributor, Author) commented Mar 24, 2026:


> Is it the case for Flux as well?

Yes. Flux works fine for batch_size > 2; I will drop the xfail once QwenImage is fixed.

> Also, let's always require get_dummy_inputs() to have batch_size. So, we can safely remove the inspect stuff from here and elsewhere.

Done. Added batch_size arguments to the newly refactored models (flux and flux2). Tests pass.
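The agreed convention can be sketched as below: every model test's `get_dummy_inputs()` takes `batch_size` directly, so batched coverage needs no signature inspection. Shapes and key names are illustrative placeholders, not the real diffusers fixtures.

```python
# Hedged sketch of the new convention: batch_size is a required-by-convention
# parameter, and every input carries a leading batch dimension.
def get_dummy_inputs(batch_size: int = 2):
    # Multiple prompts -> leading batch dimension on every input (shapes only).
    return {
        "hidden_states": (batch_size, 16, 64),         # latent tokens
        "encoder_hidden_states": (batch_size, 7, 32),  # one prompt per batch entry
        "timestep": (batch_size,),
    }


inputs = get_dummy_inputs(batch_size=3)
assert all(shape[0] == 3 for shape in inputs.values())
```

With this in place, test mixins can call `get_dummy_inputs(batch_size=...)` uniformly across models.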

@sayakpaul (Member) left a comment:


Thanks for working on this!

@sayakpaul (Member) commented:

Failing test is unrelated.

@sayakpaul sayakpaul merged commit 9d4c9dc into huggingface:main Mar 24, 2026
6 of 7 checks passed
@zhtmike zhtmike deleted the batch_qwen_ut branch March 24, 2026 03:34